from typing import Tuple, Optional
from torch.nn import LSTM, ModuleList

from BaselineModel.BaselineBaseLSTMRegressor import BaseLSTMRegressor
from Utils.Constants import Diff


class LSTMRegressor(BaseLSTMRegressor):
    """LSTM encoder followed by a stack of dense layers, used for regression."""

    def __init__(self, embedding_size: int, hidden_size: int, lstm_layers: int, inner_dense_layer_sizes: Tuple[int],
                 tb_log_path: Optional[str], bi_directional: bool, sequence_len: int = Diff.NUMBER_STEPS_SAVED,
                 batch_first: bool = True, last_layer: Optional[str] = 'sigmoid', **kwargs):
        """
        Create an LSTM with a few final dense layers for regression purposes
        :param embedding_size: size of input embeddings
        :param hidden_size: lstm hidden size
        :param lstm_layers: number of stacked lstm layers
        :param inner_dense_layer_sizes: tuple of final dense layers sizes
        :param tb_log_path: path for saving tensorboard
        :param bi_directional: True to use a bidirectional lstm
        :param sequence_len: expected size of sequences
        :param batch_first: True if the input first dimension is the batch size
        :param last_layer: name of last layer type. supported types: sigmoid / relu
        :param kwargs: extra arguments, forwarded to the base class as a single `kwargs` keyword
        """
        # NOTE(review): the extra arguments are passed as ONE keyword named `kwargs` (a dict), not
        # unpacked with `**kwargs`. Kept as-is because the base class signature is not visible here;
        # confirm that this is what BaseLSTMRegressor.__init__ expects.
        super().__init__(embedding_size=embedding_size, hidden_size=hidden_size, is_double=True,
                         lstm_layers=lstm_layers, last_layer=last_layer, tb_log_path=tb_log_path,
                         inner_dense_layer_sizes=inner_dense_layer_sizes, sequence_len=sequence_len,
                         batch_first=batch_first, kwargs=kwargs, bi_directional=bi_directional)

    def _build_model(self, embedding_size: int, hidden_size: int, lstm_layers: int, inner_dense_layer_sizes: Tuple[int],
                     last_layer: str, batch_first: bool, bi_directional: bool, **kwargs):
        """
        Create the lstm and the dense layers applied on top of it
        (presumably invoked by the base class constructor - it is not called in this class)
        :param embedding_size: size of input embeddings
        :param hidden_size: lstm hidden size
        :param lstm_layers: number of stacked lstm layers
        :param inner_dense_layer_sizes: tuple of final dense layers sizes
        :param last_layer: name of last layer type
        :param batch_first: True if the input first dimension is the batch size
        :param bi_directional: True to use a bidirectional lstm
        :param kwargs: ignored here
        """
        self._lstm = LSTM(input_size=embedding_size, hidden_size=hidden_size, num_layers=lstm_layers,
                          batch_first=batch_first, bidirectional=bi_directional)

        # a bidirectional lstm concatenates both directions, doubling the feature size of its output
        dense_input_size = hidden_size * 2 if bi_directional else hidden_size
        # `_build_dense_layers` / `_build_last_layer` are not defined in this class
        # (presumably inherited from BaseLSTMRegressor)
        dense_layers, last_size = self._build_dense_layers(dense_input_size, inner_dense_layer_sizes)
        self._dense_layers = ModuleList(dense_layers)
        self._dense_layers.extend(self._build_last_layer(last_size, last_layer))

    def forward(self, inp):
        """
        Run the lstm over the input, take its output at the last time step and pass it through the dense layers
        :param inp: batched input sequence, laid out according to the `batch_first` the lstm was built with
        :return: output of the last dense layer
        """
        lstm_out, _ = self._lstm(inp)
        # The time axis depends on the lstm layout: (batch, seq, feat) when batch_first, else (seq, batch, feat).
        # Indexing [:, -1, :] unconditionally would pick the last *batch element* when batch_first is False.
        # NOTE(review): for a bidirectional lstm, the backward half of the last time step has only
        # processed the final element of the sequence.
        if self._lstm.batch_first:
            x = lstm_out[:, -1, :]
        else:
            x = lstm_out[-1]
        for layer in self._dense_layers:
            x = layer(x)
        return x
